iT邦幫忙

2022 iThome 鐵人賽

DAY 9
0
AI & Data

JAX 好好玩系列 第 9

JAX 好好玩 (9) : JAX.NUMPY (5) : DeviceArray 初探

  • 分享至 

  • xImage
  •  

(本貼文所列出的程式碼,皆以 colab 筆記本方式執行,可由此下載

之前曾經提到,DeviceArray 是 JAX 自行定義的陣列類別,定義在 jax.numpy.DeviceArray,它的角色等同於 Numpy 中的 ndarray,現在就讓我們來更進一步的認識這個類別 [9.1]。

產生 DeviceArray

我們通常不需要直接宣告一個 DeviceArray 物件 (object, 或者也可以稱之為案例 instance),許多 jax.numpy 的 API 可以協助我們產生所需的 DeviceArray。例如:

jax.numpy.append(arr, values, axis=None)
jax.numpy.arange(start, stop=None, step=None, dtype=None)
jax.numpy.array(object, dtype=None, copy=True, order='K', ndmin=0)
jax.numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0)
jax.numpy.ones(shape, dtype=None)
jax.numpy.zeros(shape, dtype=None)

這些 API 在 Numpy 皆有對應的函式,語法和用法幾乎完全一樣,差別在於 Numpy 回傳的是 ndarray ,而 JAX 回傳 DeviceArray 。值得一提的是 jax.numpy.array() ,我們可以用它來直接將 ndarray 轉換為 DeviceArray。

# create numpy ndarray
x_np = np.arange(10)
print(f'type of x_np: {type(x_np)}')
print(f'value of x_np:{x_np}')

# convert to DeviceArray
x_jnp = jnp.array(x_np)
print(f'type of x_jnp: {type(x_jnp)}')
print(f'value of x_jnp: {x_jnp}')

output:
type of x_np: <class 'numpy.ndarray'>
value of x_np:[0 1 2 3 4 5 6 7 8 9]
type of x_jnp: <class 'jaxlib.xla_extension.DeviceArray'>
value of x_jnp: [0 1 2 3 4 5 6 7 8 9]

當我們使用 python 的 type() 來檢查 DeviceArray 型別的變數時,回傳的是:

<class 'jaxlib.xla_extension.DeviceArray'>

這是 DeviceArray 在 JAX Python 庫 “jaxlib” 實際的位置,不過為了方便,JAX 提供了別名 (alias) jax.numpy.DeviceArray ,讓大家使用。

# jax.numpy.DeviceArray is an alias of jaxlib.xla_extension.DeviceArray
isinstance(x_jnp, jnp.DeviceArray)

output:
True

DeviceArray 是不可變類別 (Immutable Class)

不可變 (immutability) 是 jax.numpy 和 Numpy 最大的差異之一!初學 JAX 的讀者,要非常注意這個部份。

在 Numpy 中我們可習慣使用以下的程式段來更改陣列元素的值:

# numpy is mutible
x = np.arange(10)
print(f'Before assignment : {x}')
x[0] = 10
print(f'After assignment : {x}')

output:
Before assignment : [0 1 2 3 4 5 6 7 8 9]
After assignment : [10 1 2 3 4 5 6 7 8 9]

然而同樣的方式,在 JAX 的 DeviceArray 上則會造成執行時錯誤。

# JAX/DeviceArray is immutible
x_jnp = jnp.arange(10)
print(f'Before assignment : {x_jnp}')
x_jnp[0] = 10

output:
Before assignment : [0 1 2 3 4 5 6 7 8 9]


*TypeError: '<class 'jaxlib.xla_extension.DeviceArray'>' object does not support item assignment. JAX arrays are immutable. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html *

因為 DeviceArray 是不可變的,所以上面程式的變數 x_jnp 在被創造出來之後,其值就不可以改變。 JAX 提供的折衷方法,是利用 x.at[idx].set(y) 這種方式:

# For updating individual elements, JAX provides an indexed update syntax that returns
# an updated copy:
x = jnp.arange(10)
y = x.at[0].set(10)
print(f'x: {x}')
print(f'y: {y}')

output:
x: [0 1 2 3 4 5 6 7 8 9]
y: [10 1 2 3 4 5 6 7 8 9]

要注意 x.at[0].set(10) 這個敍述式 (expression) 並不會更改 x 的值,它是將 x 複製一份,在副本上修改索引 0 的值為 10。所以下面這個常用的敍述:

x = x.at[0].set(10)

在執行之後,變數 x 其實已經參考到不同的記憶體位址了。我們可以檢驗看看:

# .at[].set() copy the DeviceArray.
x2 = jnp.arange(5.0)
print(f'Before at.set: {id(x2)}')
x2 = x2.at[0].set(9.9)
print(f'After at.set : {id(x2)}')
print(x2)

output:
Before at.set: 139712279962800
After at.set : 139712279963568
[9.9 1. 2. 3. 4. ]

DeviceArray.at[ ] 也可以指定索引範圍:

x2.at[2:4].set(88.88)

output:
DeviceArray([ 9.9 , 1. , 88.88, 88.88, 4. ], dtype=float32)

除了 set () 之外,DeviceArray.at[ ] 還提供了其他的運算,下表是 JAX 官方文件 [9.1] 所列出的操作,以及其對應的 Numpy (In-place) 語法,供大家參考:
https://ithelp.ithome.com.tw/upload/images/20220919/20129616XX0KvB9FmO.png

用 at[].add() 舉個例子:

jax_array = jnp.ones((5, 6))
print("original array:")
print(jax_array)

new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:")
print(new_jax_array)

output:
original array:
[[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]
[1. 1. 1. 1. 1. 1.]
[1. 1. 1. 8. 8. 8.]]

註:

[9.1] 可以參考 JAX 官方文件 jax.numpy package


上一篇
JAX 好好玩 (8) : JAX.NUMPY (4) : 用了才知道它的快
下一篇
JAX 好好玩 (10) : JAX.NUMPY (6) : 超過範圍的索引
系列文
JAX 好好玩40
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言